
[fix] bugfix 2856: Fix pre-allocated out shape check in trtllm_batch_decode_with_kv_cache_mla for q_len_per_req > 1 #2876

Merged
qsang-nv merged 2 commits into flashinfer-ai:main from qsang-nv:issue_2856
Mar 31, 2026

Conversation

@qsang-nv (Collaborator) commented Mar 24, 2026

📌 Description

This PR fixes #2856.

trtllm_batch_decode_with_kv_cache_mla rejects a correctly shaped pre-allocated out tensor when q_len_per_req > 1 (speculative decoding / MTP). The out is None path correctly infers a 4D output shape [B, q_len, H, kv_lora_rank] via query.shape[:-1] + (kv_lora_rank,), but the out is not None path hardcodes a 3D expected shape [B, H, kv_lora_rank], missing the q_len dimension.

The fix unifies both paths to use query.shape[:-1] + (kv_lora_rank,) as the expected output shape.
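The unified check can be sketched in plain Python. This is a hypothetical helper (check_out_shape is not the real flashinfer function, which validates torch tensors); it shows only the shape arithmetic both paths now share:

```python
# Illustrative sketch of the unified shape rule described above.
# check_out_shape, query_shape, and out_shape are hypothetical names.
def check_out_shape(query_shape, kv_lora_rank, out_shape=None):
    # Both paths derive the expected shape the same way: keep every
    # query dimension except head_dim, then append kv_lora_rank.
    expected = query_shape[:-1] + (kv_lora_rank,)
    if out_shape is None:
        return expected  # allocation path: infer the output shape
    if out_shape != expected:
        # validation path: same rule, so 4D pre-allocated outs now pass
        raise ValueError(f"expected {expected}, got {out_shape}")
    return expected

# 4D query [B, q_len, H, head_dim] with q_len_per_req = 3:
q = (2, 3, 16, 576)
check_out_shape(q, 512, out_shape=(2, 3, 16, 512))  # accepted after the fix
```

Before the fix, the validation path would only have accepted a 3D (2, 16, 512) buffer here, dropping the q_len dimension.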

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • Tests

    • Added comprehensive test coverage for MLA decoding with the TensorRT-LLM backend, validating both preallocated and auto-allocated output buffer modes.
  • Refactor

    • Streamlined output validation logic for improved consistency in the MLA decoding module.

@coderabbitai bot (Contributor) commented Mar 24, 2026

📝 Walkthrough

A bug fix for trtllm_batch_decode_with_kv_cache_mla that corrects output shape validation when pre-allocated out tensors are provided with q_len_per_request > 1. The validation now uses a consistent 4D shape computation across both allocation and validation paths. A test case validates both allocation modes.

Changes

Cohort / File(s) | Summary
Bug Fix (flashinfer/mla/_core.py) | Unified output shape computation in trtllm_batch_decode_with_kv_cache_mla by introducing expected_out_shape = query.shape[:-1] + (kv_lora_rank,) for both the out is None and pre-allocated out cases, fixing the rejection of correctly shaped 4D tensors when q_len_per_request > 1.
Test Coverage (tests/attention/test_trtllm_gen_mla.py) | Added parametrized test test_trtllm_batch_decode_mla_preallocated_out validating MLA decode with pre-allocated output tensors across varying batch sizes and query lengths, confirming output shape and in-place write behavior.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

Suggested labels

ready

Suggested reviewers

  • yzh119
  • bkryu
  • nvmbreughe
  • cyx-6

Poem

🐰 A shape that was broken, now fixed with care,
Four dimensions allowed, validation laid bare,
Pre-allocated tensors, no longer dismissed,
When query tokens dance beyond one—bug fixed! ✨

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name | Status | Explanation
Title check | ✅ Passed | The title accurately describes the main fix: addressing the pre-allocated output shape validation issue in trtllm_batch_decode_with_kv_cache_mla for q_len_per_req > 1.
Description check | ✅ Passed | The description provides a clear explanation of the bug, the fix, and references the linked issue #2856, meeting the template's core requirements despite missing explicit checklist completion.
Linked Issues check | ✅ Passed | The code changes directly address the requirements from issue #2856: unifying shape validation using query.shape[:-1] + (kv_lora_rank,) for both auto-allocation and pre-allocated output paths.
Out of Scope Changes check | ✅ Passed | All changes are directly scoped to fixing the identified bug and adding corresponding test coverage; no extraneous modifications are present.
Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.


@gemini-code-assist (Contributor)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical bug in the trtllm_batch_decode_with_kv_cache_mla function that previously prevented the correct use of pre-allocated output tensors during multi-token generation. By standardizing the output shape validation, the change ensures consistent and accurate handling of output buffers, improving the robustness of the system for speculative decoding and multi-token prefill scenarios.

Highlights

  • Bug Fix: Addressed a bug in the trtllm_batch_decode_with_kv_cache_mla function where pre-allocated output tensors were incorrectly rejected when q_len_per_req > 1 due to an incorrect hardcoded 3D shape expectation.
  • Shape Unification: Unified the output shape calculation logic for both out is None and out is not None paths, ensuring the correct 4D shape query.shape[:-1] + (kv_lora_rank,) is consistently used.
  • New Test Case: Introduced a new test test_trtllm_batch_decode_mla_preallocated_out to validate the fix, covering scenarios with q_len_per_req > 1 and pre-allocated output tensors.


@gemini-code-assist bot left a comment

Code Review

This pull request correctly fixes a bug in trtllm_batch_decode_with_kv_cache_mla where the shape check for a pre-allocated output tensor was incorrect for multi-token generation (q_len_per_req > 1). The fix unifies the output shape calculation for both the out=None and pre-allocated out cases, which also improves code clarity by removing duplication. A new test case has been added to verify the fix, ensuring that both paths produce identical results. The changes look good. I have one minor suggestion in the new test file to improve maintainability by reducing code duplication.

Comment on lines +856 to +863

    global global_trtllm_gen_fmha_workspace_buffer
    if global_trtllm_gen_fmha_workspace_buffer is None:
        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
            workspace_size,
            dtype=torch.int8,
            device=device,
        )
    workspace = global_trtllm_gen_fmha_workspace_buffer

Severity: medium

This workspace buffer initialization logic is duplicated from the trtllm_batch_decode_mla helper function in this file (lines 324-328). To improve maintainability and reduce code duplication, consider creating a pytest fixture to provide the workspace buffer. This would encapsulate the global variable and its initialization logic, making the tests cleaner.

For example, you could add a fixture like this:

@pytest.fixture(scope="module")
def trtllm_gen_fmha_workspace():
    """Provides a zero-initialized workspace buffer for trtllm-gen MLA tests."""
    # pytest resolves fixture parameters as fixture requests, so the device
    # is kept as a local rather than a defaulted argument.
    device = "cuda:0"
    global global_trtllm_gen_fmha_workspace_buffer
    if global_trtllm_gen_fmha_workspace_buffer is None:
        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
            workspace_size, dtype=torch.int8, device=device
        )
    return global_trtllm_gen_fmha_workspace_buffer

And then use it in the test signature.

@coderabbitai bot left a comment

🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)

887-903: Assert the preallocated buffer is actually reused.

Right now the test verifies shape/value equivalence, but it does not guarantee zero extra allocation on the out= path. Add a pointer check so the regression guard also enforces buffer reuse.

Proposed test hardening
     result_pre = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
         query=query,
         kv_cache=kv_cache,
         workspace_buffer=workspace,
         qk_nope_head_dim=qk_nope_head_dim,
         kv_lora_rank=kv_lora_rank,
         qk_rope_head_dim=qk_rope_head_dim,
         block_tables=block_tables,
         seq_lens=seq_lens,
         max_seq_len=max_seq_len,
         out=out,
         bmm1_scale=bmm1_scale,
         bmm2_scale=1.0,
         backend="trtllm-gen",
     )
+    assert result_pre.data_ptr() == out.data_ptr(), "Expected kernel to write into provided `out` tensor"
     assert result_pre.shape == expected_shape
     torch.testing.assert_close(result_none, result_pre, rtol=1e-3, atol=1e-3)


📥 Commits

Reviewing files that changed from the base of the PR and between 19bbdd3 and abe810a.

📒 Files selected for processing (2)
  • flashinfer/mla.py
  • tests/attention/test_trtllm_gen_mla.py

@coderabbitai bot left a comment

🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)

856-864: Consider zeroing shared workspace on reuse to avoid test-order coupling.

At Line 857, the global buffer is only zero-initialized on first allocation. If a prior test mutated it, this test can become order-dependent. Re-zeroing on reuse makes this test self-contained.

Proposed tweak
     global global_trtllm_gen_fmha_workspace_buffer
     if global_trtllm_gen_fmha_workspace_buffer is None:
         global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
             workspace_size,
             dtype=torch.int8,
             device=device,
         )
+    else:
+        global_trtllm_gen_fmha_workspace_buffer.zero_()
     workspace = global_trtllm_gen_fmha_workspace_buffer


📥 Commits

Reviewing files that changed from the base of the PR and between abe810a and 94498c3.

📒 Files selected for processing (1)
  • tests/attention/test_trtllm_gen_mla.py

@qsang-nv (Collaborator, Author):

/bot run

@flashinfer-bot (Collaborator):

GitLab MR !460 has been created, and the CI pipeline #46925815 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator):

[FAILED] Pipeline #46925815: 13/20 passed

@yzh119 yzh119 added the run-ci label Mar 25, 2026
@yzh119 (Collaborator) left a comment

LGTM, should be ready to merge as long as all CI passed.

@qsang-nv (Collaborator, Author):

/bot run

@flashinfer-bot (Collaborator):

GitLab MR !460 has been created, and the CI pipeline #46953779 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator):

[FAILED] Pipeline #46953779: 9/20 passed

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@coderabbitai bot left a comment

🧹 Nitpick comments (1)
flashinfer/mla/_core.py (1)

Line 752: Optional style nit: prefer the tuple-unpacking form for expected_out_shape.

To satisfy Ruff RUF005 and improve readability, consider (*query.shape[:-1], kv_lora_rank) instead of tuple concatenation.

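For reference, the two spellings the nitpick compares are equivalent; a quick check with illustrative shape values (not the real tensor dimensions):

```python
# RUF005 flags concatenating a sliced tuple with a literal tuple;
# the unpacking form expresses the same shape construction.
shape = (2, 3, 16, 576)  # e.g. [B, q_len, H, head_dim], illustrative only
kv_lora_rank = 512

concat = shape[:-1] + (kv_lora_rank,)   # tuple concatenation
unpack = (*shape[:-1], kv_lora_rank)    # unpacking form Ruff prefers

assert concat == unpack == (2, 3, 16, 512)
```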


📥 Commits

Reviewing files that changed from the base of the PR and between 94498c3 and bcaeec6.

📒 Files selected for processing (2)
  • flashinfer/mla/_core.py
  • tests/attention/test_trtllm_gen_mla.py
✅ Files skipped from review due to trivial changes (1)
  • tests/attention/test_trtllm_gen_mla.py

@qsang-nv (Collaborator, Author):

/bot run

@flashinfer-bot (Collaborator):

GitLab MR !460 has been updated with latest changes, and the CI pipeline #47236490 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator):

[FAILED] Pipeline #47236490: 13/20 passed

@qsang-nv qsang-nv merged commit 202bd18 into flashinfer-ai:main Mar 31, 2026
29 of 30 checks passed


Development

Successfully merging this pull request may close these issues.

[Bug] trtllm_batch_decode_with_kv_cache_mla rejects pre-allocated out when q_len_per_req > 1

3 participants